import decimal
import os

import pandas as pd
from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
# Define the grouped-aggregate pandas UDF that computes lx (number of survivors).
@F.pandas_udf('decimal(17,12)')
def udaf_lx(lx: pd.Series, qx: pd.Series) -> decimal.Decimal:
    """Return the final survivor count lx for one group of ordered rows.

    Applies the actuarial survivorship recurrence
        lx[i+1] = lx[i] * (1 - qx[i])
    starting from the group's first row (e.g. lx=1, qx=0.000615), quantizing
    each step to 12 decimal places to match the declared 'decimal(17,12)'
    return type. The scalar return annotation makes Spark treat this as a
    grouped-aggregate pandas UDF.

    NOTE(review): rows are consumed in the order they arrive in the Series —
    callers must guarantee the group is sorted (e.g. by age) before aggregating.

    :param lx: survivor counts per row; only the first row's value is used
               as the starting point of the recursion.
    :param qx: mortality rates per row.
    :return: final lx as a Decimal; Decimal(0) for an empty group.
    """
    tmp_lx = decimal.Decimal(0)
    tmp_qx = decimal.Decimal(0)

    for i in range(len(lx)):
        if i == 0:
            # Seed the recursion with the first row of the group.
            tmp_lx = decimal.Decimal(lx[i])
            tmp_qx = decimal.Decimal(qx[i])
        else:
            # Survivors after applying the PREVIOUS row's mortality rate,
            # rounded to 12 decimal places.
            tmp_lx = (tmp_lx * (1 - tmp_qx)).quantize(decimal.Decimal('0.000000000000'))
            tmp_qx = decimal.Decimal(qx[i])

    return tmp_lx
# Build (or reuse) the SparkSession — the original referenced an undefined
# `spark`, which only exists implicitly in pyspark shells/notebooks — then
# register the UDF so it can be used from Spark SQL as udaf_lx(...).
spark = SparkSession.builder.getOrCreate()
spark.udf.register('udaf_lx', udaf_lx)